[DETR] Remove timm hardcoded logic in modeling files #29038

Merged

@amyeroberts (Collaborator) commented Feb 15, 2024

What does this PR do?

Certain models use timm's create_model to load their backbone.

In future, all models should use load_backbone to create the backbone, which removes the need for the conditional timm logic. This can't be done for existing models: with load_backbone the backbone is loaded as a TimmBackbone class, which changes the backbone's weight names, i.e. existing checkpoints wouldn't be compatible.

This PR makes it possible to configure the loaded timm backbone completely through the model config, removing the hard-coded values in the modeling files. So, for users, it's the same as if load_backbone were being used.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@amyeroberts amyeroberts force-pushed the remove-timm-in-modeling-files branch from bc8d155 to 0a38d16 Compare February 19, 2024 15:04
@amyeroberts amyeroberts force-pushed the remove-timm-in-modeling-files branch from 4d524f2 to 33c89b9 Compare March 7, 2024 19:11
@@ -141,23 +161,6 @@ def backward(context, grad_output):
return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None


if is_scipy_available():
Collaborator Author

Moves these above the MultiScaleDeformableAttentionFunction definition - better matching library patterns

      raise ValueError("You can't specify both `backbone_kwargs` and `backbone_config`.")

- if not use_timm_backbone:
+ if use_timm_backbone and backbone_kwargs is None:
Collaborator Author

These replicate the defaults that are used to load a timm backbone in the modeling file. This PR makes it possible to configure the loaded timm backbone using the standard backbone API; the defaults here are for backwards compatibility.
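
As a rough illustration of the config-side defaulting described in this comment, here is a minimal standalone sketch. The function name `resolve_backbone_kwargs` is hypothetical and is not part of transformers; the default values are assumed to be the ones previously hard-coded in the DETR modeling file (`out_indices` of 1 to 4, `output_stride` of 16 when dilation is enabled, `in_chans` taken from `num_channels`).

```python
from typing import Any, Dict, Optional


def resolve_backbone_kwargs(
    use_timm_backbone: bool,
    backbone_kwargs: Optional[Dict[str, Any]] = None,
    backbone_config: Optional[Any] = None,
    dilation: bool = False,
    num_channels: int = 3,
) -> Optional[Dict[str, Any]]:
    """Hypothetical helper mirroring the config-side default handling."""
    # Explicit kwargs and an explicit backbone config are two competing
    # descriptions of the backbone, so passing both is rejected.
    if backbone_kwargs and backbone_config is not None:
        raise ValueError("You can't specify both `backbone_kwargs` and `backbone_config`.")

    if use_timm_backbone and backbone_kwargs is None:
        # Assumed defaults: the values the modeling file used to hard-code.
        backbone_kwargs = {"out_indices": [1, 2, 3, 4], "in_chans": num_channels}
        if dilation:
            backbone_kwargs["output_stride"] = 16
    return backbone_kwargs
```

Under these assumptions, a user who passes nothing gets the old behaviour, while a user who passes `backbone_kwargs` explicitly has them respected unchanged.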

@@ -354,17 +354,20 @@ def __init__(self, config):

self.config = config

# For backwards compatibility we have to use the timm library directly instead of the AutoBackbone API
if config.use_timm_backbone:
Collaborator Author

Unfortunately, we can't remove the timm logic here and use load_backbone instead. When using load_backbone, a timm model is loaded as a TimmBackbone class, which means the loaded weight names are different from those produced by the create_model call here. For backwards compatibility (being able to load existing checkpoints) we need to leave this as-is.

Instead, to be compatible with the backbone API and remove the hard-coding, we allow the backbone behaviour to be specified through backbone_kwargs.
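
The weight-naming problem can be pictured with plain dictionaries standing in for state dicts. This is only a toy sketch: the prefixes `model` and `wrapper` and the parameter names are made up for illustration and are not the actual attribute names used by transformers or timm.

```python
def add_prefix(state_dict, prefix):
    """Return a copy of state_dict with every key nested under prefix."""
    return {f"{prefix}.{name}": value for name, value in state_dict.items()}


def diff_keys(checkpoint, model_keys):
    """Report keys the model expects but the checkpoint lacks, and vice versa."""
    missing = sorted(set(model_keys) - set(checkpoint))
    unexpected = sorted(set(checkpoint) - set(model_keys))
    return missing, unexpected


# Toy backbone parameters (names are illustrative only).
raw_backbone = {"conv1.weight": 0, "layer1.0.conv1.weight": 0}

# Existing checkpoints: the backbone module is held directly by the model.
checkpoint = add_prefix(raw_backbone, "model")

# Hypothetical new layout: the same backbone sits one level deeper, inside a wrapper class.
wrapped_model = add_prefix(add_prefix(raw_backbone, "wrapper"), "model")

missing, unexpected = diff_keys(checkpoint, wrapped_model)
```

Every backbone weight ends up both missing and unexpected, which is why swapping the loading path would break existing checkpoints unless the keys were remapped.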

Member

Ok, makes sense!

@amyeroberts amyeroberts changed the title Remove timm in modeling files [DETR] Remove timm hardcoded logic in modeling files Mar 8, 2024
@amyeroberts amyeroberts mentioned this pull request Mar 8, 2024
9 tasks
@amyeroberts amyeroberts requested a review from LysandreJik March 12, 2024 20:53
@amyeroberts amyeroberts mentioned this pull request Mar 18, 2024
8 tasks
@huggingface huggingface deleted a comment from github-actions bot Apr 8, 2024
@amyeroberts amyeroberts force-pushed the remove-timm-in-modeling-files branch from dbc1355 to 51eb6d3 Compare April 8, 2024 14:11
@amyeroberts amyeroberts requested review from ArthurZucker and removed request for LysandreJik April 8, 2024 14:12
@LysandreJik (Member) left a comment

Ok I think that makes sense! No problem for me, thanks for working on this @amyeroberts

Comment on lines -357 to +363
- kwargs = {}
+ kwargs = getattr(config, "backbone_kwargs", {})
+ kwargs = {} if kwargs is None else kwargs.copy()
+ out_indices = kwargs.pop("out_indices", (1, 2, 3, 4))
+ num_channels = kwargs.pop("in_chans", config.num_channels)
Member

I would maybe add a few comments here to explain what's happening for posterity
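
One way the commented version of this logic might read is sketched below. The wrapper function `timm_create_model_args` is hypothetical, and the returned dict merely stands in for the keyword arguments that would be forwarded to timm's create_model; the config is a plain `SimpleNamespace` rather than a real model config.

```python
from types import SimpleNamespace


def timm_create_model_args(config):
    """Hypothetical helper: resolve the timm backbone arguments from a config."""
    # Older configs may not define backbone_kwargs at all, hence getattr.
    kwargs = getattr(config, "backbone_kwargs", {})
    # The attribute can also be explicitly None. Copy otherwise, so the
    # pops below never mutate the dict stored on the config.
    kwargs = {} if kwargs is None else kwargs.copy()
    # Fall back to the previously hard-coded values when the user gave none.
    out_indices = kwargs.pop("out_indices", (1, 2, 3, 4))
    num_channels = kwargs.pop("in_chans", config.num_channels)
    # Whatever remains in kwargs is passed through untouched.
    return dict(out_indices=out_indices, in_chans=num_channels, **kwargs)


config = SimpleNamespace(backbone_kwargs={"out_indices": (2, 3)}, num_channels=3)
args = timm_create_model_args(config)
```

Copying before popping matters because the same config object may be reused to build a second model or be serialized afterwards.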

Comment on lines -425 to +438
- kwargs = {}
+ kwargs = getattr(config, "backbone_kwargs", {})
+ kwargs = {} if kwargs is None else kwargs.copy()
+ out_indices = kwargs.pop("out_indices", (2, 3, 4) if config.num_feature_levels > 1 else (4,))
+ num_channels = kwargs.pop("in_chans", config.num_channels)
  if config.dilation:
-     kwargs["output_stride"] = 16
+     kwargs["output_stride"] = kwargs.get("output_stride", 16)
Member

Same here
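
For this second hunk, the two points worth spelling out are the default `out_indices` that depends on `num_feature_levels` and the switch from overwriting `output_stride` to `kwargs.get`. A small sketch follows, using a hypothetical function name and a `SimpleNamespace` in place of the real config class.

```python
from types import SimpleNamespace


def deformable_backbone_args(config):
    """Hypothetical helper mirroring the multi-scale variant of the logic."""
    kwargs = getattr(config, "backbone_kwargs", {})
    kwargs = {} if kwargs is None else kwargs.copy()
    # Several feature levels need several stages; a single level uses only the last stage.
    default_indices = (2, 3, 4) if config.num_feature_levels > 1 else (4,)
    out_indices = kwargs.pop("out_indices", default_indices)
    num_channels = kwargs.pop("in_chans", config.num_channels)
    if config.dilation:
        # .get keeps a user-supplied stride; a plain assignment would overwrite it with 16.
        kwargs["output_stride"] = kwargs.get("output_stride", 16)
    return out_indices, num_channels, kwargs


single_level = SimpleNamespace(num_feature_levels=1, num_channels=3, dilation=True)
custom_stride = SimpleNamespace(
    num_feature_levels=4, num_channels=3, dilation=True, backbone_kwargs={"output_stride": 8}
)
```

With `dilation=True`, a stride supplied through `backbone_kwargs` survives, and 16 is used only as the fallback.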

@ArthurZucker (Collaborator) left a comment

LGTM +1 on the comments for posterity! 🤗 sorry for being slow here

@amyeroberts amyeroberts force-pushed the remove-timm-in-modeling-files branch from 51eb6d3 to 80b32cc Compare April 26, 2024 11:00
@amyeroberts amyeroberts merged commit aafa7ce into huggingface:main Apr 26, 2024
20 checks passed
@amyeroberts amyeroberts deleted the remove-timm-in-modeling-files branch April 26, 2024 15:55
itazap pushed a commit that referenced this pull request May 14, 2024
* Enable instantiating model with pretrained backbone weights

* Clarify pretrained import

* Use load_backbone instead

* Add backbone_kwargs to config

* Fix up

* Add tests

* Tidy up

* Enable instantiating model with pretrained backbone weights

* Update tests so backbone checkpoint isn't passed in

* Clarify pretrained import

* Update configs - docs and validation check

* Update src/transformers/utils/backbone_utils.py

Co-authored-by: Arthur <[email protected]>

* Clarify exception message

* Update config init in tests

* Add test for when use_timm_backbone=True

* Use load_backbone instead

* Add use_timm_backbone to the model configs

* Add backbone_kwargs to config

* Pass kwargs to constructors

* Draft

* Fix tests

* Add back timm - weight naming

* More tidying up

* Whoops

* Tidy up

* Handle when kwargs are none

* Update tests

* Revert test changes

* Deformable detr test - don't use default

* Don't mutate; correct model attributes

* Add some clarifying comments

* nit - grammar is hard

---------

Co-authored-by: Arthur <[email protected]>